{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b33a915d-cc1d-4102-a2ee-159c02e6c579",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "57c0a1b8-e339-4f6b-941e-7af7b902de7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from langchain import PromptTemplate, LLMChain\n",
    "from kg_rag.utility import *\n",
    "from tqdm import tqdm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2672548d-7d25-4f3c-94d1-d19206049076",
   "metadata": {},
   "outputs": [],
   "source": [
    "QUESTION_PATH = config_data[\"MCQ_PATH\"]\n",
    "SYSTEM_PROMPT = system_prompts[\"MCQ_QUESTION\"]\n",
    "CONTEXT_VOLUME = int(config_data[\"CONTEXT_VOLUME\"])\n",
    "QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data[\"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD\"])\n",
    "QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data[\"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY\"])\n",
    "VECTOR_DB_PATH = config_data[\"VECTOR_DB_PATH\"]\n",
    "NODE_CONTEXT_PATH = config_data[\"NODE_CONTEXT_PATH\"]\n",
    "SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL\"]\n",
    "SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL\"]\n",
    "SAVE_PATH = config_data[\"SAVE_RESULTS_PATH\"]\n",
    "\n",
    "MODEL_NAME = 'PharMolix/BioMedGPT-LM-7B'\n",
    "BRANCH_NAME = 'main'\n",
    "CACHE_DIR = config_data[\"LLM_CACHE_DIR\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c753b053-be44-4ddb-8d55-3bf434428954",
   "metadata": {},
   "outputs": [],
   "source": [
    "INSTRUCTION = \"Context:\\n\\n{context} \\n\\nQuestion: {question}\"\n",
    "\n",
    "vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)\n",
    "embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)\n",
    "node_context_df = pd.read_csv(NODE_CONTEXT_PATH)\n",
    "edge_evidence = False\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f18c9efb-556c-4b37-8b00-e06a73a19f86",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:19<00:00,  6.66s/it]\n",
      "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR)               \n",
    "template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)\n",
    "prompt = PromptTemplate(template=template, input_variables=[\"context\", \"question\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0370d703-4e18-4c78-9e9a-2030b498253e",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain = LLMChain(prompt=prompt, llm=llm)    \n",
    "question_df = pd.read_csv(QUESTION_PATH)  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "275f4171-3be7-46ca-bf16-18160ce72f3b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Out of the given list, which Gene is associated with psoriasis and Takayasu's arteritis. Given list is: SHTN1, HLA-B,  SLC14A2,  BTBD9,  DTNB\""
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "question_df.iloc[0].text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cc5a65fb-6bd3-4948-84e5-f404af83d3f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (4135 > 2048). Running this sequence through the model will result in indexing errors\n",
      "This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.\n",
      "0it [04:19, ?it/s]\n",
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "answer_list = []\n",
    "question_df = question_df.sample(50, random_state=40)\n",
    "for index, row in tqdm(question_df.iterrows()):\n",
    "    question = row[\"text\"]\n",
    "    context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence)\n",
    "    output = llm_chain.run(context=context, question=question)\n",
    "    print(output)\n",
    "    input('press enter')\n",
    "    answer_list.append((row[\"text\"], row[\"correct_node\"], output))\n",
    "answer_df = pd.DataFrame(answer_list, columns=[\"question\", \"correct_answer\", \"llm_answer\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94eb325d-17d4-4013-907d-7a38dabaea56",
   "metadata": {},
   "outputs": [],
   "source": [
    "answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) \n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
